package com.kool.kmqtt.server.repository.local;

import com.kool.kmqtt.server.ServerConfig;
import com.kool.kmqtt.server.exception.AppException;
import com.kool.kmqtt.server.exception.ErrorCode;
import com.kool.kmqtt.server.log.ClientNoAckSendPacketCnt;
import com.kool.kmqtt.server.packet.Packet;
import com.kool.kmqtt.server.repository.Repository;
import com.kool.kmqtt.server.repository.subscription.SubscribeInfo;
import com.kool.kmqtt.server.repository.subscription.Subscriber;
import com.kool.kmqtt.server.repository.subscription.Subscription;
import com.kool.kmqtt.server.session.WillStatus;
import com.kool.kmqtt.util.StringUtil;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.Collectors;

/**
 * Default {@link Repository} implementation that keeps its state in memory.
 *
 * <p>Will messages, unacknowledged inbound/outbound packets and the fail-over counters are held in
 * {@link ConcurrentHashMap}s owned by this instance; subscriptions and retained messages are
 * delegated to {@code SubscribeTrees} and {@code RetainTrees}.
 */
public class DefaultRepository implements Repository {

    /**
     * Deletes the session state of a client: its subscriptions, its will message and its
     * unacknowledged outbound PUBLISH / PUBREL packets.
     *
     * @param clientId client identifier
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void deleteSessionStatus(String clientId) {
        // Remove every subscription of the client.
        deleteSubscriber(clientId);
        // Remove the will message.
        deleteWillStatus(clientId);
        // Remove the unacknowledged outbound PUBLISH / PUBREL packets.
        deleteSendPackets(clientId);
    }
/////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Will message per client.
     * key: clientId
     * value: will status
     */
    private final ConcurrentHashMap<String, WillStatus> willStatusMap = new ConcurrentHashMap<>();

    /**
     * Stores (or replaces) the will message of a client.
     *
     * @param clientId   client identifier
     * @param willStatus will status to store; the backing map rejects {@code null} values
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void saveWillStatus(String clientId, WillStatus willStatus) {
        if (clientId == null) {
            throw new AppException(ErrorCode.SAVE_WILL_STATUS_CLIENT_ID_NULL);
        }
        willStatusMap.put(clientId, willStatus);
    }

    /**
     * Removes the will message of a client, if any.
     *
     * @param clientId client identifier
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void deleteWillStatus(String clientId) {
        if (clientId == null) {
            throw new AppException(ErrorCode.DELETE_WILL_STATUS_CLIENT_ID_NULL);
        }
        willStatusMap.remove(clientId);
    }

    /**
     * Looks up the will message of a client.
     *
     * @param clientId client identifier
     * @return the stored will status, or {@code null} when none is stored
     */
    @Override
    public WillStatus getWillStatus(String clientId) {
        return willStatusMap.get(clientId);
    }
/////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Subscription store, obtained from {@code SubscribeTrees.getInstance()}.
     */
    private final SubscribeTrees subscribeTrees = SubscribeTrees.getInstance();

    /**
     * Returns all subscriptions, as produced by {@code SubscribeTrees.getAllCopy()}.
     *
     * @return all known subscriptions
     */
    @Override
    public List<Subscription> getSubscriptions() {
        return subscribeTrees.getAllCopy();
    }

    /**
     * Stores a topic filter subscription for a client.
     *
     * @param clientId      client identifier
     * @param subscribeInfo subscription to store; {@code null} is silently ignored
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void saveSubscribeInfo(String clientId, SubscribeInfo subscribeInfo) {
        if (clientId == null) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        if (subscribeInfo != null) {
            subscribeTrees.add(clientId, subscribeInfo);
        }
    }

    /**
     * Returns the subscribers whose topic filters match the given topic.
     *
     * @param topic topic name to match
     * @return the matching subscribers, as produced by {@code SubscribeTrees.match}
     * @throws AppException if {@code topic} is {@code null}
     */
    @Override
    public List<Subscriber> getSubscriber(String topic) {
        if (topic == null) {
            throw new AppException(ErrorCode.TOPIC_NULL);
        }
        return subscribeTrees.match(topic);
    }

    /**
     * Removes a single topic filter subscription of a client.
     * Does nothing when either argument is {@code null}.
     *
     * @param clientId    client identifier
     * @param topicFilter topic filter to unsubscribe
     */
    @Override
    public void deleteSubscribeInfo(String clientId, String topicFilter) {
        if (clientId != null && topicFilter != null) {
            subscribeTrees.delete(topicFilter, clientId);
        }
    }

    /**
     * Removes every subscription of a client.
     *
     * @param clientId client identifier
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void deleteSubscriber(String clientId) {
        if (clientId == null) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        subscribeTrees.deleteSubscriber(clientId);
    }
/////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Unacknowledged inbound PUBLISH (QoS 1 / QoS 2) and PUBREL packets.
     * key = packetId
     * value = Packet
     *
     * NOTE(review): the map is keyed by packet id only, so packets of different clients that use
     * the same packet id overwrite each other. Fixing this requires a client id on the
     * get/delete methods of the Repository interface.
     */
    private final ConcurrentHashMap<Integer, Packet> receivePacketMap = new ConcurrentHashMap<>();

    /**
     * Stores an unacknowledged inbound PUBLISH (QoS 1 / QoS 2) or PUBREL packet.
     *
     * @param clientId client identifier (currently not used as part of the key)
     * @param packet   packet to store
     * @throws AppException if the packet has no packet id
     */
    @Override
    public void saveReceivePacket(String clientId, Packet packet) {
        if (packet.getPacketId() == null) {
            throw new AppException(ErrorCode.SAVE_PACKET_ID_NULL);
        }
        receivePacketMap.put(packet.getPacketId(), packet);
    }

    /**
     * Looks up an unacknowledged inbound PUBLISH (QoS 1 / QoS 2) or PUBREL packet.
     *
     * @param packetId packet identifier
     * @return the stored packet, or {@code null} when none is stored under that id
     */
    @Override
    public Packet getReceivePacket(int packetId) {
        return receivePacketMap.get(packetId);
    }

    /**
     * Removes an unacknowledged inbound PUBLISH (QoS 1 / QoS 2) or PUBREL packet, if present.
     *
     * @param packetId packet identifier
     */
    @Override
    public void deleteReceivePacket(int packetId) {
        receivePacketMap.remove(packetId);
    }

    /////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Retained message store, obtained from {@code RetainTrees.getInstance()}.
     */
    private final RetainTrees retainTrees = RetainTrees.getInstance();

    /**
     * Stores a retained message for a topic.
     *
     * @param topicName topic name
     * @param packet    retained packet
     */
    @Override
    public void saveRetainPacket(String topicName, Packet packet) {
        retainTrees.add(topicName, packet);
    }

    /**
     * Returns the retained messages matching a topic filter.
     *
     * @param topicFilter topic filter
     * @return the matching retained packets, as produced by {@code RetainTrees.match}
     */
    @Override
    public List<Packet> getRetainPacket(String topicFilter) {
        return retainTrees.match(topicFilter);
    }

    /**
     * Clears the retained messages of a topic.
     *
     * @param topicName topic name
     */
    @Override
    public void deleteRetainPackets(String topicName) {
        retainTrees.delete(topicName);
    }

/////////////////////////////////////////////////////////////////////////////////////////////////////////
    /**
     * Unacknowledged outbound PUBLISH / PUBREL packets per client.
     * key: clientId
     * value: packets of that client, keyed by packet id
     */
    private final ConcurrentHashMap<String, ConcurrentHashMap<Integer, Packet>> sendPacketMap = new ConcurrentHashMap<>();

    /**
     * Stores an unacknowledged outbound PUBLISH or PUBREL packet for a client.
     *
     * @param clientId client identifier
     * @param packet   packet to store
     * @throws AppException if {@code clientId} is {@code null} or the packet has no packet id
     */
    @Override
    public void saveSendPackets(String clientId, Packet packet) {
        if (clientId == null) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        if (packet.getPacketId() == null) {
            throw new AppException(ErrorCode.SAVE_PACKET_ID_NULL);
        }
        // computeIfAbsent creates the per-client map atomically and always hands back a non-null
        // map, even when deleteSendPackets runs concurrently for the same client.
        sendPacketMap.computeIfAbsent(clientId, key -> new ConcurrentHashMap<>())
                .put(packet.getPacketId(), packet);
    }

    /**
     * Removes all unacknowledged outbound PUBLISH / PUBREL packets of a client.
     *
     * @param clientId client identifier
     * @throws AppException if {@code clientId} is {@code null}
     */
    @Override
    public void deleteSendPackets(String clientId) {
        if (clientId == null) {
            throw new AppException(ErrorCode.CLIENT_ID_NULL);
        }
        sendPacketMap.remove(clientId);
    }

    /**
     * Returns the unacknowledged outbound PUBLISH / PUBREL packets of a client, ordered by send
     * time with the most recently sent packet first.
     *
     * <p>This is a pure lookup: no entry is created for a client that has no stored packets.
     *
     * @param clientId client identifier
     * @return a new list of the client's packets; empty when nothing is stored
     */
    @Override
    public List<Packet> getSendPackets(String clientId) {
        Map<Integer, Packet> pending = sendPacketMap.get(clientId);
        if (pending == null) {
            return new ArrayList<>();
        }
        // Snapshot the values first, then sort by send timestamp in descending order.
        return new ArrayList<>(pending.values()).stream()
                .sorted(Comparator.comparing(Packet::getSendTime).reversed())
                .collect(Collectors.toList());
    }

    /**
     * Removes one unacknowledged outbound PUBLISH or PUBREL packet of a client.
     * Does nothing when {@code clientId} is empty or the client has no stored packets.
     *
     * @param clientId client identifier
     * @param packetId packet identifier
     */
    @Override
    public void deleteSendPacket(String clientId, int packetId) {
        if (StringUtil.isEmpty(clientId)) {
            return;
        }
        Map<Integer, Packet> pending = sendPacketMap.get(clientId);
        if (pending != null) {
            pending.remove(packetId);
        }
    }

    /**
     * Counts the unacknowledged outbound packets of every client that currently has an entry.
     *
     * @return one record per client with the number of its unacknowledged outbound packets
     */
    @Override
    public List<ClientNoAckSendPacketCnt> countSendPacket() {
        List<ClientNoAckSendPacketCnt> logs = new ArrayList<>();
        for (Map.Entry<String, ConcurrentHashMap<Integer, Packet>> entry : sendPacketMap.entrySet()) {
            ClientNoAckSendPacketCnt log = new ClientNoAckSendPacketCnt();
            log.setClientId(entry.getKey());
            log.setCnt(entry.getValue().size());
            logs.add(log);
        }
        return logs;
    }

    /**
     * Fail-over (circuit breaker) counters.
     * key: remote address
     * value: failure counter together with its expiry time
     */
    ConcurrentHashMap<String, FailoverCounter> failoverCounterMap = new ConcurrentHashMap<>();

    /**
     * Returns the number of connect errors currently recorded for a remote address.
     * An expired counter is discarded and reported as zero.
     *
     * @param remoteAddress remote address of the client
     * @return {@code null} when {@code remoteAddress} is {@code null}, otherwise the error count
     */
    @Override
    public Integer getClientConnectErrorTimes(String remoteAddress) {
        if (remoteAddress == null) {
            // Nothing can be looked up without an address.
            return null;
        }
        FailoverCounter counter = failoverCounterMap.get(remoteAddress);
        if (counter == null || counter.getCounter() == null) {
            return 0;
        }
        if (System.currentTimeMillis() > counter.getExpireTime()) {
            // Remove only the counter that was inspected, so a fresh counter installed
            // concurrently by increaseFailoverCounter is not thrown away.
            failoverCounterMap.remove(remoteAddress, counter);
            return 0;
        }
        return counter.getCounter().get();
    }

    /**
     * Records one more connect error for a remote address.
     *
     * <p>When no counter exists, or the existing one has expired, a new counter starting at 1 is
     * installed with an expiry of now plus {@code ServerConfig.getFailoverCounterResetSeconds()}
     * seconds; otherwise the existing counter is incremented and its expiry is left unchanged.
     * The update runs atomically per address.
     *
     * @param remoteAddress  remote address of the client; {@code null} is ignored
     * @param timeoutSeconds currently unused; the window length comes from {@code ServerConfig}
     *                       (NOTE(review): confirm whether this parameter should be honoured)
     */
    @Override
    public void increaseFailoverCounter(String remoteAddress, int timeoutSeconds) {
        if (remoteAddress == null) {
            // Nothing can be recorded without an address.
            return;
        }
        final long now = System.currentTimeMillis();
        final long expireTime = now + ServerConfig.getInstance().getFailoverCounterResetSeconds() * 1000L;
        failoverCounterMap.compute(remoteAddress, (address, existing) -> {
            boolean usable = existing != null
                    && existing.getCounter() != null
                    && now <= existing.getExpireTime();
            if (usable) {
                existing.getCounter().incrementAndGet();
                return existing;
            }
            // Absent or expired: start a new counting window so this failure is not lost.
            FailoverCounter fresh = new FailoverCounter();
            fresh.setCounter(new AtomicInteger(1));
            fresh.setExpireTime(expireTime);
            return fresh;
        });
    }
}
